
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 7 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp

Conversation


@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm and RMSNorm, which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head!
  • Credit to @pstjohn for coming up with this idea!

Usage / Documentation

(tp_mesh and weight_mesh can also be passed in TEModule.__init__.)

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic resides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we use transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and also inherit the requires_grad attribute from the DTensor parameter, since the local Tensor has it unset after DTensor.from_local(Tensor) for FP8 parameters specifically!
    • Post-backward, the Tensor gradient is converted and attached to the DTensor.grad attribute.
      • NOTE(@cspades, @vthumbe1503): For some reason, FusibleOperation (RMSNorm and LayerNorm) require casting the gradient from Tensor to a DTensor matching the configuration of the DTensor weights. I have confirmed the gradient is installed correctly on RMSNorm weights (same shape and sharding configuration as the sharded optimizer state), and it will not affect normal TransformerEngine operations, but it is not totally clear why this is necessary with FSDP2 x TP.

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • TransformerEngineBaseModule: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. on the scale of +/-1 to the uint8 parameter checkpoint!

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistributing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor will perform the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse / un-sharding of this redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)
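The steps above can be simulated with plain Python list slicing. This is a toy model of the redistribute, assuming evenly divisible shards; strided_then_shard is an illustrative name, not a PyTorch API:

```python
def strided_then_shard(data, dp, tp):
    """Simulate (_StridedShard(dim=0, sf=tp), Shard(dim=0)) on a 1-D list.

    Step 1: pre-shard by the shard factor, the product of the parallelism
            sizes of all Shard placements to the right (here just tp).
    Step 2: split each pre-shard dp ways and concatenate stripe-wise to
            form the strided DP shards.
    Step 3: shard each strided DP shard across TP.
    Returns shards indexed as shards[dp_rank][tp_rank].
    """
    n = len(data)
    # Step 1: TP pre-shards, e.g. [0 1 2 3] and [4 5 6 7].
    pre = [data[i * n // tp:(i + 1) * n // tp] for i in range(tp)]
    # Step 2: DP rank r takes slice r of every pre-shard, giving strided
    # shards [0 4] [1 5] [2 6] [3 7] rather than left-to-right [0 1] [2 3] ...
    dp_shards = []
    for r in range(dp):
        stripe = []
        for p in pre:
            m = len(p)
            stripe += p[r * m // dp:(r + 1) * m // dp]
        dp_shards.append(stripe)
    # Step 3: final TP shard within each DP stripe.
    return [[s[t * len(s) // tp:(t + 1) * len(s) // tp] for t in range(tp)]
            for s in dp_shards]

shards = strided_then_shard(list(range(8)), dp=4, tp=2)
# DP rank 0 holds the stripe [0, 4]; its TP shards are [0] and [4],
# matching the (DP=4, TP=2) diagram above.
```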

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds full DCP (Distributed Checkpoint) compatibility for FSDP2 × TP strided sharding across all TransformerEngineBaseModules by introducing a set_device_mesh(tp_mesh, weight_mesh) API that converts module parameters to appropriately-placed DTensors prior to fully_shard. It also adds a Float8 / MXFP8 all-gather guard for DTensor-wrapped out buffers, and fixes the amax_reduction_group selection (previously hardcoded to a "shard" mesh dimension).

Key changes:

  • New set_device_mesh API on Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, DotProductAttention, TransformerLayer, LayerNorm, and RMSNorm — converts parameters to TP-sharded or Replicated DTensors before fully_shard so FSDP2 can build the correct _StridedShard placements.
  • _convert_param_to_dtensor_param / _extract_trainable_tensor_from_dtensor helpers in distributed.py centralise DTensor conversion and requires_grad propagation.
  • DTensor-aware forward / backward in ops/basic/layer_norm.py and ops/basic/rmsnorm.py — extracts local tensors for CUDA kernels and re-wraps weight gradients as DTensors so FSDP2 gradient accumulation works correctly.
  • Bug fix in _LayerNormMLP backward: the pre-existing isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) check was always False (quantizer objects are never QuantizedTensorStorage); the corrected check isinstance(ctx.fc1_weight, QuantizedTensorStorage) properly guards the update_usage(columnwise_usage=True) call.
  • DCP checkpoint round-trip test added in run_fsdp2_model.py with AppState save / load, pre-save vs post-load loss parity, and model / optimizer state-dict comparison.
  • Known limitations: GroupedLinear and FusibleOperation DCP support are explicitly deferred; DelayedScaling has a ±1 uint8 parity issue at checkpoint boundaries (noted in PR, test skipped for that recipe).

Confidence Score: 3/5

  • Functional parity tests pass on Llama 8B, but the feature is broad and several edge cases (GroupedLinear, DelayedScaling checkpointing, post-init set_device_mesh ordering) are deferred or known-broken.
  • The core FSDP2-TP DTensor conversion logic and DCP checkpoint test are sound and backed by Megatron CI parity runs. However, the pre-flagged args.sharding_dims None-guard issues in the test runner remain, the DelayedScaling checkpoint has an acknowledged parity disparity, GroupedLinear support is explicitly incomplete, and the _set_tensor_parallel_attributes / set_tensor_model_parallel_attributes assertion ordering creates a fragility for post-init set_device_mesh calls that the API permits but the implementation does not fully protect against.
  • tests/pytorch/distributed/run_fsdp2_model.py (pre-flagged None-guard issues at lines 116 and 379 still present), transformer_engine/pytorch/module/linear.py + layernorm_linear.py + layernorm_mlp.py (_set_tensor_parallel_attributes assertion fragility when set_device_mesh is used post-init).

Important Files Changed

Filename Overview
tests/pytorch/distributed/run_fsdp2_model.py Adds DCP checkpoint save/load test with AppState, but retains the pre-flagged args.sharding_dims None-guard and f-string issues at lines 116 and 379; test structure and autocast guards are otherwise correct.
transformer_engine/pytorch/distributed.py Adds _convert_param_to_dtensor_param and _extract_trainable_tensor_from_dtensor helpers; logic is correct and well-documented.
transformer_engine/pytorch/module/layernorm_mlp.py Adds full set_device_mesh with TP DTensor conversions; fixes a pre-existing backward bug (ctx.fc1_weight_quantizer → ctx.fc1_weight in isinstance check); unconditional bias DTensor conversion when use_bias=False is a known deferred issue.
transformer_engine/pytorch/module/layernorm_linear.py Adds set_device_mesh and refactors _set_tensor_parallel_attributes out of reset_parameters; new _get_bias_tensors and _get_layernorm_weight_and_bias helpers correctly extract DTensor local tensors.
transformer_engine/pytorch/module/linear.py Adds set_device_mesh with TP placement logic and refactors weight/bias tensor extraction with DTensor support; return type of _get_weight_and_bias_tensors changed from Tuple to List (all callers use unpacking, so no runtime impact).
transformer_engine/pytorch/ops/basic/layer_norm.py Adds DTensor local-extraction in forward/backward and correctly wraps grad_weight/grad_bias as DTensors in backward to satisfy FSDP2 gradient accumulation requirements for Replicate-placed parameters.
transformer_engine/pytorch/tensor/float8_tensor.py Adds guard to extract _local_tensor from DTensor out before FP8 all-gather post-processing, since to_local() is unsupported under Torch Dispatch for quantized tensors.

Sequence Diagram

sequenceDiagram
    participant U as User Code
    participant TEM as TEModule.__init__
    participant SDM as set_device_mesh()
    participant CP as _convert_param_to_dtensor_param()
    participant FS as fully_shard()
    participant RP as reset_parameters()
    participant FWD as Forward Pass
    participant DCP as torch.distributed.checkpoint

    U->>TEM: TEModule(tp_size=N, tp_mesh=mesh["tp"], weight_mesh=mesh["dp_shard","tp"])
    TEM->>TEM: register_parameter() — plain tensors
    TEM->>TEM: init_fp8_metadata()
    TEM->>SDM: set_device_mesh(tp_mesh, weight_mesh)
    SDM->>CP: convert weight → DTensor(Shard(dim=0))
    SDM->>CP: convert bias → DTensor(Shard(dim=0) or Replicate)
    SDM->>SDM: set amax_reduction_group on Float8CurrentScalingQuantizer
    TEM->>RP: reset_parameters(defer_init=False or True)
    U->>FS: fully_shard(model, mesh=mesh["dp_replicate","dp_shard"])
    Note over FS: Sees DTensor Shard(dim=0) on TP dim →<br/>builds _StridedShard(dim=0, sf=tp_size) × Shard(dim=0)
    U->>RP: reset_parameters() [meta device only]
    loop Training
        FWD->>FWD: _extract_trainable_tensor_from_dtensor(weight)
        FWD->>FWD: C++ kernel on local tensor
        FWD->>FWD: attach DTensor.grad post-backward
    end
    U->>DCP: torch.distributed.checkpoint.save(AppState)
    DCP->>DCP: get_state_dict() — evict _extra_state
    U->>DCP: torch.distributed.checkpoint.load(AppState)
    DCP->>DCP: set_state_dict(strict=False)

Comments Outside Diff (2)

  1. transformer_engine/pytorch/module/layernorm_mlp.py, line 1451-1453 (link)

    Correct bug fix — ctx.fc1_weight_quantizer → ctx.fc1_weight

    The original code checked isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage), but fc1_weight_quantizer is a Quantizer object, not a QuantizedTensorStorage, so that condition was always False — meaning ctx.fc1_weight.update_usage(columnwise_usage=True) was silently never called. The new check correctly tests whether the weight itself is a QuantizedTensorStorage, matching the intent of the guard. This is a real latent bug surfaced and fixed by this PR.
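    The dead-guard pattern is easy to reproduce in miniature. These are generic stand-in classes, not the actual TE types:

    ```python
    class Quantizer:                  # stands in for TE's Quantizer
        def quantize(self):
            return QuantizedTensorStorage()

    class QuantizedTensorStorage:     # stands in for TE's QuantizedTensorStorage
        def __init__(self):
            self.columnwise_usage = False
        def update_usage(self, columnwise_usage):
            self.columnwise_usage = columnwise_usage

    quantizer = Quantizer()
    weight = quantizer.quantize()

    # Buggy guard: a Quantizer is never a QuantizedTensorStorage, so this
    # branch is dead and update_usage() is silently skipped.
    if isinstance(quantizer, QuantizedTensorStorage):
        weight.update_usage(columnwise_usage=True)
    assert weight.columnwise_usage is False

    # Fixed guard from this PR: test the weight itself.
    if isinstance(weight, QuantizedTensorStorage):
        weight.update_usage(columnwise_usage=True)
    assert weight.columnwise_usage is True
    ```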

  2. transformer_engine/pytorch/module/linear.py, line 1343-1345 (link)

    Potential assertion failure when set_device_mesh is called post-initialization

    _set_tensor_parallel_attributes calls set_tensor_model_parallel_attributes, which asserts that none of ("tensor_model_parallel", "partition_dim", "partition_stride") are already set on the tensor. This is safe when set_device_mesh is called inside __init__ (before reset_parameters runs), because _convert_param_to_dtensor_param copies param.__dict__ and the plain params do not yet carry TP attributes.

    However, if a user calls set_device_mesh after a non-meta-device __init__ that has already run reset_parameters (the external-call usage pattern described in the PR docs), the converted DTensor would inherit the TP attributes from the source param. Any subsequent reset_parameters() call would then hit the assertion:

    assert not hasattr(tensor, attribute)   # fails — attribute already set
    

    The same risk exists in LayerNormLinear._set_tensor_parallel_attributes and LayerNormMLP._set_tensor_parallel_attributes.

    Consider either (a) skipping set_tensor_model_parallel_attributes when the tensor already has TP attributes, or (b) adding a clear guard that set_device_mesh must be called before reset_parameters with an explicit error rather than a silent wrong-order assertion failure.

Last reviewed commit: 439b1aa

@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 Compare March 4, 2026 18:10
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b Compare March 5, 2026 16:06
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from c912f5b to 2aadb35 Compare March 5, 2026 18:30
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 3 times, most recently from a7a17c2 to bc82f02 Compare March 6, 2026 17:02
@vthumbe1503
Collaborator

/te-ci L1 pytorch
